
Conversation

@benchislett commented Sep 8, 2025

What does this PR do?

Type of change: Feature support for offline training of EAGLE3 heads using the HuggingFace training script

Overview:

This PR contains two primary components:

  • New helper scripts, ported from previous EAGLE3 training scripts, to facilitate easy data preparation for offline training. These include:

    • prepare_input_conversations: short Python scripts that load commonly used training datasets and output a standardized JSONL dataset file ready for training (see the example record after this list)
    • gen_synthetic_conversations: scripts to generate synthetic conversations, using a conversation dataset as a collection of prompts for the model. Currently, OpenAI-compatible endpoints are used to generate the conversations. This is a known bottleneck of dataset preparation, so high performance is key. Any serving engine can be used; an example script demonstrates how to launch vLLM for inference.
    • collect_hidden_states: scripts to extract and save hidden states generated from a conversation dataset, for use in offline training. These include a script that can send the completion requests to a local inference server running with a patched inference loop, as well as a script that uses HF transformers AutoModel to generate the hidden states explicitly. Using an inference server is more performant, but either will work well for generating hidden states.
  • Support for offline training in the EAGLE3 training scripts. This is triggered by passing --offline-data X to launch_train.sh, which will then launch main.py with: --offline-training True --offline-data-path $OFFLINE_DATA_PATH --omit-target-layers False. See below for example usage.
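
For reference, here is a minimal sketch of what one record in the standardized JSONL file might look like (field names follow the prepare_input_conversations utilities; the ID value and message text below are placeholders, not output of the actual scripts):

import json

# Hypothetical record in the standardized conversation format: a stable ID plus
# a list of role/content messages. The ID scheme shown here is illustrative only.
record = {
    "conversation_id": "example-000_<content-hash>",
    "conversations": [
        {"role": "user", "content": "Explain speculative decoding in one paragraph."},
    ],
}
print(json.dumps(record, ensure_ascii=False))  # one such line per conversation in the .jsonl file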

Currently, the inline evaluation using ar_validate.py does not work when the target model's hidden layers are deleted, so the memory footprint of the output checkpoints is not smaller than with online training. However, all performance gains during training should still be present. --omit-target-layers controls this behaviour and will be re-enabled when offline support is added to the validation script.

Usage

Here are the steps to reproduce an offline training run of Llama 3.2 1B-Instruct on the Daring-Anteater dataset:

# First, prepare the conversations
python3 prepare_input_conversations/add_daring_anteater.py

# Then generate synthetic conversations
vllm serve meta-llama/Llama-3.2-1B-Instruct # launch any LLM server
bash gen_synthetic_conversations/send_completion_reqs_openai.sh
# don't forget to shut down the LLM server

# Compute hidden states
bash collect_hidden_states/run_hf_compute_hiddens.sh

# Launch training
OUTPUT_DIR=ckpts/llama1b-$(date +%Y%m%d_%H%M)
mkdir -p "$(dirname "$OUTPUT_DIR")"
./launch_train.sh --model meta-llama/Llama-3.2-1B-Instruct \
            --output_dir $OUTPUT_DIR \
            --offline-data /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \
            --data synthetic_conversations/daring-anteater.jsonl \
            --num_gpu 8 \
            --num_epochs 2 \
            --eagle_config eagle_config.json

Note that all commands are expected to be run with modelopt installed, from the base directory examples/speculative_decoding.

Testing

Training was tested and evaluated on the example setup above, reporting ~2.2 AL after 2 epochs in all cases. Offline and online training produced nearly identical loss curves and acceptance rates at each evaluation.

  • Is this change backward compatible?: Yes. Online training should not be affected in any way.
  • Did you write any new necessary tests?: No. Unsure if we have unit tests for training
  • Did you add or update any necessary documentation?: No. Waiting for Feat: update eagle3 example; add export #293 to land before updating any READMEs / docs.
  • Did you update Changelog?: No. TODO.

Summary by CodeRabbit

  • New Features

    • Offline training support with precomputed hidden states, new data-path flag, and default sequence length increased to 2048.
    • Dataset preparation tools for Daring-Anteater, MTBench, ShareGPT, UltraChat, plus utilities to mix/manage splits.
    • Hidden-state workflows and utilities: compute, send, sample, and multi-GPU runner scripts; model updated to support offline paths.
  • Chores

    • Expanded .gitignore to ignore input_conversations, synthetic_conversations, and ckpts; added example dataset-builder script; removed legacy launcher.

copy-pr-bot bot commented Sep 8, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

coderabbitai bot commented Sep 8, 2025

Walkthrough

Adds dataset preparation utilities, tools to compute/send/sample per-conversation hidden states, an offline training path that consumes precomputed .pt hidden-states (dataset/collator/dataloader changes and CLI/launcher wiring), transformer plugin adjustments for offline execution, .gitignore updates, and removal of a legacy launcher.

Changes

  • Ignore updates (examples/speculative_decoding/.gitignore): Appends ignore patterns: input_conversations, synthetic_conversations, and ckpts.
  • Hidden-state collection package & scripts (examples/speculative_decoding/collect_hidden_states/__init__.py, .../compute_hidden_states_hf.py, .../sample_hidden_states.py, .../send_conversations_for_hiddens.py, .../run_hf_compute_hiddens.sh, .../run_hf_compute_hiddens_dp.sh, .../run_send_conversations.sh): New package and CLIs/scripts to compute (HF), send (OpenAI-compatible), sample, and orchestrate collection of per-conversation hidden states; produce per-conversation .pt artifacts containing input_ids, hidden_states, aux_hidden_states, and conversation_id (see the loading sketch after this list).
  • Prepare input conversations utilities & scripts (examples/speculative_decoding/prepare_input_conversations/__init__.py, .../utils.py, .../add_daring_anteater.py, .../add_mtbench.py, .../add_sharegpt.py, .../add_ultrachat.py, .../example_make_prompt_dataset.sh): New utilities for downloading/parsing datasets, stable conversation ID hashing, deduplication and append-to-splits, mixing/splitting strategies, and dataset-conversion scripts for Daring-Anteater, MTBench, ShareGPT, UltraChat, plus an example dataset-build script.
  • Offline training integration (EAGLE) (examples/speculative_decoding/eagle_utils.py, examples/speculative_decoding/main.py, examples/speculative_decoding/launch_train.sh, examples/speculative_decoding/train_eagle3_and_export.sh): Adds OfflineSupervisedDataset and DataCollatorForOffline, extends make_eagle_supervised_data_module(..., use_offline_training), adds DataArguments.offline_data_path, wires launcher flags/validation, and adjusts model loading/config for offline (e.g., store original hidden-layer count, set eagle_offline).
  • Transformer plugin updates (modelopt/torch/speculative/plugins/transformers.py): Registers HFEagleModel with the offline registry, renames the forwarded arg to past_key_values, reads num_orig_hidden_layers when eagle_offline, prefers the lm_head device for offline, uses DynamicCache() as dummy past_key_values when base_model_outputs is present, and removes DetachedHFEagleModel.
  • Removed legacy launcher (examples/speculative_decoding/launch.sh): Deleted the legacy bash launcher that previously orchestrated speculative-decoding experiments.
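
To make that artifact layout concrete, here is a minimal loading sketch, assuming the key set listed in the hidden-state collection item above (the file path is a placeholder):

import torch

# Inspect one per-conversation artifact written by the collect_hidden_states scripts.
# The expected keys come from the change summary above; the path is illustrative.
data = torch.load("hidden_states/<conversation_id>.pt", map_location="cpu")
expected_keys = {"input_ids", "hidden_states", "aux_hidden_states", "conversation_id"}
missing = expected_keys - set(data.keys())
assert not missing, f"artifact is missing keys: {missing}"
print(data["conversation_id"], tuple(data["input_ids"].shape), tuple(data["hidden_states"].shape))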

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as User
  participant Prep as Prepare scripts
  participant HF as compute_hidden_states_hf
  participant API as OpenAI-like server
  participant DS as Disk (.pt)
  participant Trainer as Trainer (offline)

  U->>Prep: run add_* → produce input_conversations JSONL
  alt local HF compute
    U->>HF: compute_hidden_states_hf (--model, --input-file)
    HF->>DS: write per-conversation .pt (input_ids, hidden_states, aux_hidden_states)
  else send to server
    U->>API: send_conversations_for_hiddens (--base-url, --input-file)
    API->>DS: server writes per-conversation .pt
  end
  U->>Trainer: launch_train.sh --offline-data PATH
  Trainer->>DS: read .pt files
  Trainer->>Trainer: OfflineSupervisedDataset + DataCollatorForOffline → training with eagle_offline=true
sequenceDiagram
  autonumber
  participant C as send_conversations_for_hiddens
  participant Tok as Tokenizer
  participant Meta as /tmp/meta.json
  participant API as OpenAI-like endpoint
  participant Out as Output Dir

  C->>Tok: apply_chat_template(conversations) -> input_ids / prompt
  C->>Meta: write {conversation_id, output_file}
  C->>API: completions.create(model, prompt, max_tokens=1)
  alt success
    API-->>C: 200 OK
    Note right of API: serving engine coordinates writing .pt to Out
    API->>Out: .pt created
  else error / too long
    API-->>C: error
    C->>C: increment counters, skip
  end
  C->>Meta: cleanup entry

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

A whisker twitch, I gather seeds of text,
I hop through prompts, keep hidden thoughts indexed.
In little .pt burrows, tokens sleep tight,
Offline carrots glow through training-night.
Hooray — new paths, more hops, and snacks in sight! 🥕🐇

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check (✅ Passed): The title "Feature: Offline training for EAGLE3" succinctly and accurately captures the primary change in the pull request—adding offline training support for EAGLE3 (dataset prep, hidden-state collection, CLI flags, and model adjustments). It is specific, concise, and clearly related to the changeset contents.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f74bf59 and bd27b86.

📒 Files selected for processing (1)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs



# Extend the data sample with the hidden states from the .pt file
max_length = self.tokenizer.model_max_length
offline_data = torch.load(offline_file_path)
@h-guo18 (Contributor) commented Sep 10, 2025


Curious:

  1. With the current implementation, will there be any pre-fetching mechanism for this tensor loading (perhaps taken care of inside the HF trainer)?
  2. Considering there's a limit on total disk bandwidth, will data loading possibly become a bottleneck limiting the training speed (if we further optimize the training loop, e.g. the TTT part)?

@benchislett (Author) replied:

  1. Yes, I believe standard PyTorch/HF datasets will handle prefetching, multiple loading worker processes, etc.
  2. Indeed, disk bandwidth can absolutely bottleneck training. If this becomes an issue, we can offset this using techniques such as compressed hidden-states on-disk, or using a file system with faster read speeds (e.g. using many smaller disks for high parallel read throughput). However, the speed-of-light for even a single moderate-quality disk is quite good, so networking a few 1-4TB disks together can easily saturate the (optimized) GPU throughput.
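
As an illustration of the prefetching mentioned in point 1, here is a rough PyTorch DataLoader sketch (the dataset class and path below are stand-ins, not this PR's OfflineSupervisedDataset or data layout):

from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset


# Stand-in dataset: each item is one per-conversation .pt file loaded from disk.
class PtFileDataset(Dataset):
    def __init__(self, data_dir: str):
        self.files = sorted(Path(data_dir).glob("*.pt"))

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> dict:
        return torch.load(self.files[idx], map_location="cpu")


# With num_workers > 0, worker processes run __getitem__ (and its torch.load) in the
# background while the GPU trains on the current batch; prefetch_factor controls how
# many batches each worker keeps queued. The PR's DataCollatorForOffline would do the
# real batching; collate_fn=list just keeps this sketch self-contained.
loader = DataLoader(
    PtFileDataset("hidden_states/"),  # placeholder path
    batch_size=4,
    num_workers=8,
    prefetch_factor=2,
    pin_memory=True,
    collate_fn=list,
)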

@yeyu-nvidia (Contributor) commented:

Can we remove gen_synthetic_conversations from the PR? We already have https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/speculative_decoding/distributed_generate

@benchislett benchislett force-pushed the bchislett/offline-eagle-training branch 2 times, most recently from 4282dbe to f6cb37a on September 15, 2025 17:13
codecov bot commented Sep 15, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.82%. Comparing base (682bf6d) to head (bd27b86).

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #300      +/-   ##
==========================================
- Coverage   73.82%   73.82%   -0.01%     
==========================================
  Files         172      172              
  Lines       17438    17438              
==========================================
- Hits        12874    12873       -1     
- Misses       4564     4565       +1     

☔ View full report in Codecov by Sentry.

@h-guo18 (Contributor) commented Sep 15, 2025

LGTM. Tried the offline workflow with TinyLlama and got a reasonable AR.

@h-guo18 h-guo18 self-requested a review September 16, 2025 18:19
@benchislett benchislett force-pushed the bchislett/offline-eagle-training branch from 1839ab5 to f92be76 on September 16, 2025 19:23
@benchislett benchislett marked this pull request as ready for review September 16, 2025 19:24
@benchislett benchislett requested a review from a team as a code owner September 16, 2025 19:24
@benchislett (Author) commented:

/ok to test f92be76

@coderabbitai bot left a comment

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
examples/speculative_decoding/launch_train.sh (1)

144-170: --multi_gpu is passed to accelerate instead of main.py (likely unrecognized).

The flag appears intended for main.py but is placed before main.py, so accelerate may reject it. Move it after main.py.

-export TOKENIZERS_PARALLELISM=False
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+export TOKENIZERS_PARALLELISM=False
+CMD="accelerate launch --mixed_precision bf16 main.py \
+    $MULTI_GPU \
     --mode $MODE \
     --model_name_or_path $MODEL \
     --training_seq_len $TRAINING_SEQ_LEN \
@@
-    --data_path $DATA \
+    ${DATA_ARGS} \
     $OFFLINE_TRAINING_ARGS \
     $SPECULATIVE_ARGS
 "

Add DATA_ARGS a few lines above (see next comment).

modelopt/torch/speculative/plugins/transformers.py (2)

120-132: Typo: rcache_position should be cache_position.

This will raise an unexpected keyword error in HF models.

-                rcache_position=cache_position,
+                cache_position=cache_position,

310-318: LlamaDecoderLayer expects past_key_value (singular), not past_key_values.

HF 4.48–4.57 use past_key_value on the decoder layer; passing past_key_values will error. Keep model-level plural, layer-level singular.

-                past_key_values=past_key_values,
+                past_key_value=past_key_values,

If you need to support future HF where the layer accepts plural, add a small shim with signature inspection and forward the appropriate kwarg.

♻️ Duplicate comments (3)
modelopt/torch/speculative/plugins/transformers.py (1)

811-813: Good: dummy DynamicCache for offline + legacy cache adapter.

This addresses earlier Cache init issues across transformers versions.

Also applies to: 827-828

examples/speculative_decoding/main.py (1)

187-194: Thanks for wiring eagle_offline into the config.

This addresses prior feedback to use EagleConfig’s flag. LGTM.

examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1)

143-143: Document the serving engine requirement inline.

Add a short comment pointing to the patched vLLM branch or required changes so users aren’t blocked. This addresses the earlier reviewer question.

🧹 Nitpick comments (33)
examples/speculative_decoding/.gitignore (1)

2-4: Ignore additional generated artifacts (export, hidden-states dirs).

Add common outputs created by the new workflow so they don’t get committed accidentally.

 Daring-Anteater
 input_conversations
 synthetic_conversations
 ckpts
+export
+hidden_states
examples/speculative_decoding/launch_train.sh (3)

91-95: Guard divide-by-zero when no GPUs are visible.

If torch reports 0 GPUs, DEFAULT_SAVE_STEPS errors. Provide a sane fallback.

 GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+# Calculate save_steps with fallback
+if [[ "${GPU_COUNT:-0}" -gt 0 ]]; then
+  DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+else
+  echo "Warning: No CUDA devices found. Falling back to GPU_COUNT=1 for save_steps."
+  DEFAULT_SAVE_STEPS=8192
+fi

107-109: Remove unused variables (REDRAFTER_*).

They’re set but never used in this script.

-REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
-REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}

81-84: Error message shows only the value after '='.

For bad args, print the full token for clarity.

-      >&2 printf "Error: Invalid argument ${1#*=}\n"
+      >&2 printf "Error: Invalid argument: %s\n" "$1"
examples/speculative_decoding/train_eagle3_and_export.sh (2)

61-65: Quote OFFLINE_DATA_PATH and validate early.

Preempt path issues and early-fail on typos.

-if [[ "$OFFLINE_DATA_PATH" != "" ]]; then
-  OFFLINE_DATA_ARGS="--offline-data $OFFLINE_DATA_PATH"
+if [[ -n "$OFFLINE_DATA_PATH" ]]; then
+  if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then
+    echo "Offline data path not found: $OFFLINE_DATA_PATH" >&2
+    exit 1
+  fi
+  OFFLINE_DATA_ARGS="--offline-data \"$OFFLINE_DATA_PATH\""
 else
   OFFLINE_DATA_ARGS=""
 fi

72-79: Quote all interpolations in the training call.

Prevents breakage with spaces/special chars.

-./launch_train.sh --model $BASE_MODEL \
-            --output_dir $OUTPUT_DIR \
-            $OFFLINE_DATA_ARGS \
-            --data $DATA \
-            --num_gpu $NUM_GPU \
+./launch_train.sh --model "$BASE_MODEL" \
+            --output_dir "$OUTPUT_DIR" \
+            $OFFLINE_DATA_ARGS \
+            --data "$DATA" \
+            --num_gpu "$NUM_GPU" \
             --num_epochs 2 \
             --eagle_config eagle_config.json
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (3)

63-63: Load on CPU to avoid accidental GPU allocations.

These artifacts are saved on CPU; loading on CPU is safer and avoids device mismatches.

-        data = torch.load(file)
+        data = torch.load(file, map_location="cpu")

50-55: Validate input path exists and is a dir/file.

Currently a non-existent path silently yields 0 files; fail fast.

-    if args.input_path.is_file():
+    if not args.input_path.exists():
+        raise FileNotFoundError(f"Input path not found: {args.input_path}")
+    if args.input_path.is_file():
         all_files = [args.input_path]

70-73: Allow superset of keys for forward-compat.

Future producers may include extra fields (e.g., metadata). Check for required keys instead of exact match.

-        if set(expected_keys) != set(data.keys()):
+        if not set(expected_keys).issubset(set(data.keys())):
             print(f"File {file} does not contain all expected keys: {expected_keys}")
             print(f"  Found keys: {list(data.keys())}")
             continue
modelopt/torch/speculative/plugins/transformers.py (1)

425-435: Device selection logic for offline mode — LGTM with a tiny robustness nit.

Using lm_head device in offline mode is correct. Consider guarding for models without lm_head (edge adapters).

-        if eagle_offline:
+        if eagle_offline:
             # For offline training, the base model has no layers.
             # Read the device from the lm_head instead.
-            device = self.lm_head.weight.device
+            device = getattr(self.lm_head, "weight", self.model.embed_tokens.weight).device
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)

89-91: Set eval mode for deterministic, faster inference.

Call model.eval() after loading.

-    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model.eval()

78-88: This script is synchronous; drop asyncio for simplicity.

No awaits inside main. Convert to a regular function and remove asyncio import/usage.

@@
-import asyncio
@@
-async def main(args: argparse.Namespace) -> None:
+def main(args: argparse.Namespace) -> None:
@@
 if __name__ == "__main__":
     cli_args = parse_args()
-    asyncio.run(main(cli_args))
+    main(cli_args)

Also applies to: 18-21, 166-168


115-123: Separate “too short” and “too long” counters/logs for clarity.

The variable name num_skipped_too_long also counts short samples (<=10). Track them separately to make logs actionable.

@@
-    num_skipped_too_long = 0
+    num_skipped_too_long = 0
+    num_skipped_too_short = 0
@@
-        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
-            num_skipped_too_long += 1
+        if num_input_tokens <= 10:
+            num_skipped_too_short += 1
+            continue
+        if num_input_tokens > args.max_seq_len:
+            num_skipped_too_long += 1
             continue
@@
-    if num_skipped_too_long > 0:
+    if num_skipped_too_short > 0:
+        print(f"Skipped {num_skipped_too_short} conversations for being too short (<=10 tokens).")
+    if num_skipped_too_long > 0:
         print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")

Also applies to: 153-163

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)

25-36: Add shebang/safety flags, quote vars, and ensure cleanup on interruption.

Harden the runner and keep temp files tidy even on Ctrl‑C.

+#!/usr/bin/env bash
+set -euo pipefail
+
 INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
 OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
 
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" /tmp/part-
+trap 'rm -f /tmp/part-*.jsonl' EXIT
 
 for i in $(seq 0 7)
 do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
+CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
+  --model meta-llama/Llama-3.2-1B-Instruct \
+  --input-file "/tmp/part-0${i}.jsonl" \
+  --output-dir "$OUTPUT_DIR" &
 done
 wait
 
-rm /tmp/part-*.jsonl
+# Files are removed by trap on EXIT

Also applies to: 1-15

examples/speculative_decoding/eagle_utils.py (2)

214-220: Load offline tensors onto CPU explicitly to avoid accidental CUDA deserialization.

If any .pt file was saved with CUDA tensors, torch.load will attempt to place them on GPU and can OOM. Force CPU map_location.

-        offline_data = torch.load(offline_file_path)
+        offline_data = torch.load(offline_file_path, map_location="cpu")

353-380: Batching: align hidden-state sequence length with input_ids to prevent downstream surprises.

You pad hidden states to max_hs_length independently of base_batch padding. If these diverge (shouldn’t, but can in edge cases), training code may assume equal lengths. Consider asserting equality or padding HS to base_batch length.

-        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
+        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
+        # Optional safety: ensure HS length matches token length
+        base_seq_len = base_batch["input_ids"].shape[1]
+        assert max_hs_length == base_seq_len, (
+            f"Hidden-state length ({max_hs_length}) != token length ({base_seq_len})."
+        )
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (2)

1-1: Add a shebang to satisfy shellcheck and ensure correct interpreter.

Without a shebang, shells may pick inconsistent interpreters and SC2148 is triggered.

Apply this diff:

+#!/usr/bin/env bash
+set -euo pipefail

19-23: Make the script path-robust and fix the CLI flag name.

  • Using a hardcoded relative path will break if users cd into this directory.
  • The commented flag name doesn't match the Python CLI (--debug-max-num-conversations).

Apply this diff:

-python3 collect_hidden_states/send_conversations_for_hiddens.py \
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+python3 "${SCRIPT_DIR}/send_conversations_for_hiddens.py" \
   --model meta-llama/Llama-3.2-1B-Instruct \
   --input-file synthetic_conversations/mtbench.jsonl \
   --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
-# --debug-max-num-conversations-per-split 1000
+# --debug-max-num-conversations 1000
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)

1-1: Add a shebang to make the script executable and lint-clean.

+#!/usr/bin/env bash
+set -euo pipefail
examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (2)

41-47: Accept both --output-split-name and --output-split for CLI consistency.

The example script uses --output-split; add it as an alias to avoid UX breakage.

-    parser.add_argument(
-        "--output-split-name",
+    parser.add_argument(
+        "--output-split-name", "--output-split",
         type=str,
         default="ultrachat",
         help=dataset_splits_explanation("ultrachat"),
     )

58-79: Nit: async not needed; code is synchronous.

Optional: drop async and asyncio.run for simplicity, or await any future async I/O.

examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (2)

45-51: Support --output-split alias to match example usage.

-    parser.add_argument(
-        "--output-split-name",
+    parser.add_argument(
+        "--output-split-name", "--output-split",
         type=str,
         default="mtbench",
         help=dataset_splits_explanation("mtbench"),
     )

89-92: Use the full message list when hashing the conversation ID for consistency.

Other scripts hash the normalized message list; hashing only the prompt string is inconsistent.

-        prompt_id = f"mtbench-{entry['question_id']:03}_" + id_for_conversation(prompt)
-        input_conversations.append(
-            {"conversation_id": prompt_id, "conversations": [{"role": "user", "content": prompt}]}
-        )
+        msgs = [{"role": "user", "content": prompt}]
+        prompt_id = f"mtbench-{entry['question_id']:03}_" + id_for_conversation(msgs)
+        input_conversations.append(
+            {"conversation_id": prompt_id, "conversations": msgs}
+        )
examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (2)

41-46: Accept --output-split alias to align with examples.

-    parser.add_argument(
-        "--output-dir",
+    parser.add_argument(
+        "--output-dir",
         type=Path,
         default=Path("input_conversations/"),
         help="Path to save the conversations file(s) into. Default is 'input_conversations/'.",
     )

And:

-    parser.add_argument(
-        "--output-split-name",
+    parser.add_argument(
+        "--output-split-name", "--output-split",
         type=str,
         default="daring-anteater",
         help=dataset_splits_explanation("daring-anteater"),
     )

87-90: Skip empty conversations to avoid downstream failures.

Downstream scripts expect non-empty conversations; append only if messages were extracted.

-            input_conversations.append(
-                {"conversation_id": prompt_id, "conversations": processed_conversations}
-            )
+            if processed_conversations:
+                input_conversations.append(
+                    {"conversation_id": prompt_id, "conversations": processed_conversations}
+                )
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (4)

104-107: Close the OpenAI client to avoid connection leaks.

Use a context manager to ensure the underlying httpx client is closed.

-    client: AsyncOpenAI = AsyncOpenAI(
-        api_key=args.openai_api_key,
-        base_url=args.base_url,
-    )
+    async with AsyncOpenAI(api_key=args.openai_api_key, base_url=args.base_url) as client:

Also indent the subsequent usage accordingly.


175-186: Catch OpenAI client errors instead of httpx directly; keep a broad fallback.

The OpenAI SDK wraps HTTP errors; catching httpx.HTTPStatusError likely won’t trigger. Prefer openai.OpenAIError (and optionally BadRequestError).

-        except httpx.HTTPStatusError as e:
-            print(f"HTTP error for conversation {conversation_id}: {e}")
-            num_error += 1
-            continue
-        except openai.BadRequestError:
+        except openai.BadRequestError:
             # Most likely the conversation is too long, ignore
             num_too_long += 1
             continue
+        except openai.OpenAIError as e:
+            print(f"OpenAI client error for conversation {conversation_id}: {e}")
+            num_error += 1
+            continue

143-153: Meta file lifecycle: leave fewer footguns.

If a conversation is skipped early (length), the temp meta file persists until the next loop. You already guard at the top, but proactively removing on skip (see earlier diff) reduces surprises if users parallelize.


188-193: Minor: redundant continue.

The loop proceeds to the next iteration anyway.

-        num_success += 1
-        continue
+        num_success += 1
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)

62-85: Make parsing resilient to unknown roles and non‑string values.

Avoid aborting on unexpected roles and guard against non‑string "value" fields.

Apply:

 def parse_sharegpt_conversation(sharegpt_conv: dict) -> list[dict] | None:
@@
-        elif turn.get("from") == "bing":
-            # Bing conversations are skipped for training, omit it
-            return None
+        elif turn.get("from") == "bing":
+            # Bing conversations are skipped for training, omit it
+            return None
         else:
-            err_msg = f"Unknown role in conversation: {turn.get('from')}"
-            raise ValueError(err_msg)
+            # Skip unknown roles rather than abort the whole run
+            print(f"Warning: Unknown role in conversation: {turn.get('from')}, skipping turn.")
+            continue
-
-        value = turn.get("value", "").strip()
-        if value:
-            msgs.append({"role": role, "content": value})
+        raw = turn.get("value", "")
+        if not isinstance(raw, str):
+            # Ignore non-string payloads
+            continue
+        value = raw.strip()
+        if value:
+            msgs.append({"role": role, "content": value})
examples/speculative_decoding/prepare_input_conversations/utils.py (3)

101-108: Float equality is brittle; allow small tolerance.

Avoid false negatives due to FP rounding.

Apply:

-    if train_ratio + val_ratio + test_ratio != 1.0:
-        msg = "Ratios must sum to 1.0"
+    total = train_ratio + val_ratio + test_ratio
+    if abs(total - 1.0) > 1e-9:
+        msg = f"Ratios must sum to 1.0 (got {total})"
         raise ValueError(msg)

113-116: Avoid mutating global RNG state when shuffling.

Use a local Random instance to keep determinism without side effects.

Apply:

-    if shuffle:
-        random.seed(seed)
-        random.shuffle(conversations)
+    if shuffle:
+        rng = random.Random(seed)
+        rng.shuffle(conversations)

155-170: Help text prints “%%” instead of “%”.

Use single percent signs so argparse help is readable.

Apply:

-        - 'mix': Conversations will be randomly mixed and distributed into
-            'train' (80%%), 'val' (10%%), and 'test' (10%%) splits.
-        - 'mix_test': Conversations will be randomly mixed and distributed into
-            'val' (50%%) and 'test' (50%%) splits.
+        - 'mix': Conversations will be randomly mixed and distributed into
+            'train' (80%), 'val' (10%), and 'test' (10%) splits.
+        - 'mix_test': Conversations will be randomly mixed and distributed into
+            'val' (50%) and 'test' (50%) splits.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9aedfdf and f92be76.

📒 Files selected for processing (21)
  • examples/speculative_decoding/.gitignore (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
  • examples/speculative_decoding/eagle_utils.py (4 hunks)
  • examples/speculative_decoding/launch.sh (0 hunks)
  • examples/speculative_decoding/launch_train.sh (4 hunks)
  • examples/speculative_decoding/main.py (4 hunks)
  • examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/launch.sh
🧰 Additional context used
🧬 Code graph analysis (12)
examples/speculative_decoding/prepare_input_conversations/utils.py (2)
modelopt/torch/utils/random.py (2)
  • random (59-61)
  • shuffle (148-150)
modelopt/torch/_deploy/_runtime/common.py (1)
  • write_bytes (65-67)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
  • train (117-271)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
  • parse_args (28-75)
  • main (78-163)
examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (4)
  • dataset_splits_explanation (155-170)
  • download_file (26-34)
  • id_for_conversation (37-41)
  • update_dataset_file_with_conversations (125-152)
examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (4)
  • dataset_splits_explanation (155-170)
  • download_file (26-34)
  • id_for_conversation (37-41)
  • update_dataset_file_with_conversations (125-152)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
  • parse_args (28-75)
  • main (78-163)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  • parse_args (25-46)
  • main (49-84)
examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (3)
  • dataset_splits_explanation (155-170)
  • id_for_conversation (37-41)
  • update_dataset_file_with_conversations (125-152)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
  • parse_args (30-90)
  • main (93-206)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  • parse_args (25-46)
  • main (49-84)
examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1)
examples/speculative_decoding/prepare_input_conversations/utils.py (3)
  • dataset_splits_explanation (155-170)
  • id_for_conversation (37-41)
  • update_dataset_file_with_conversations (125-152)
examples/speculative_decoding/eagle_utils.py (1)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
  • make_eagle_supervised_data_module (238-312)
modelopt/torch/speculative/plugins/transformers.py (2)
modelopt/torch/speculative/eagle/eagle_model.py (1)
  • EagleModel (23-51)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
  • _set_default_aux_hidden_state_layers (682-694)
🪛 Shellcheck (0.10.0)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (8)
examples/speculative_decoding/prepare_input_conversations/__init__.py (1)

1-16: LGTM — license header and package docstring look good.

examples/speculative_decoding/collect_hidden_states/__init__.py (1)

1-16: LGTM — clean package init with correct licensing.

examples/speculative_decoding/launch_train.sh (1)

127-136: Don't pass --data_path when using offline data; quote offline path.

File: examples/speculative_decoding/launch_train.sh Lines: 127-136

Passing an empty --data_path can break the CLI; quote paths to handle spaces and let the offline pipeline supply data.

 if [[ "$OFFLINE_DATA_PATH" != "" ]]; then
   if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then
     echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory."
     exit 1
   else
-    OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH"
+    OFFLINE_TRAINING_ARGS="--offline-data-path \"$OFFLINE_DATA_PATH\""
+    DATA_ARGS=""   # let offline pipeline supply data
   fi
 else
   OFFLINE_TRAINING_ARGS=""
+  DATA_ARGS="--data_path $DATA"
 fi
@@
-    --data_path $DATA \
+    ${DATA_ARGS} \

If main.py requires an explicit --offline-training flag, wire it here similarly.

Also applies to: 167-169

examples/speculative_decoding/eagle_utils.py (1)

238-307: Potential tokenization mismatch with offline data.

preprocess() removes the think‑stripping snippet from tokenizer.chat_template, but compute_hidden_states_hf.py originally didn’t. That will make input_ids differ and trip the shape check here. After adopting the matching replacement in compute_hidden_states_hf.py, please re‑verify that shapes match end‑to‑end on a sample.

examples/speculative_decoding/main.py (1)

236-241: AR validation: early return when disabled looks good.

Short‑circuit is correct and won’t interfere with control flow. LGTM.

examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (2)

87-101: Ensure runtime deps are documented/installed (aiohttp, tqdm).

Add these to an examples requirements file or docs so first‑time runs don’t fail.


110-118: Conversation ID construction looks solid.

Stable hash + source id prefixing should dedupe reliably.

examples/speculative_decoding/prepare_input_conversations/utils.py (1)

37-42: LGTM: stable conversation hashing.

Deterministic SHA‑256 over normalized JSON is appropriate for deduping.
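
As a concrete illustration of that scheme, here is a minimal sketch (the exact normalization and digest handling in utils.id_for_conversation may differ; this is an assumption, not the PR's code):

import hashlib
import json


# Deterministic ID: SHA-256 over a normalized JSON dump of the message list.
# The normalization choices (key sorting, encoding) are assumptions for illustration.
def id_for_conversation(messages: list[dict]) -> str:
    normalized = json.dumps(messages, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()


msgs = [{"role": "user", "content": "Hello!"}]
assert id_for_conversation(msgs) == id_for_conversation(list(msgs))  # stable across calls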

@benchislett benchislett force-pushed the bchislett/offline-eagle-training branch from f92be76 to f74bf59 on September 17, 2025 14:24
@benchislett (Author) commented:

/ok to test f74bf59

@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (5)
examples/speculative_decoding/main.py (1)

188-190: Good: honor EagleConfig.eagle_offline per prior feedback.
This aligns with the earlier request to use the existing config flag.

examples/speculative_decoding/prepare_input_conversations/utils.py (2)

26-36: Add a network timeout and avoid multi-context in one line for clarity.

Prevents indefinite hangs and improves readability; keep parent dir creation.

Apply:

 async def download_file(url: str, destination: Path) -> None:
     """Download a file from a URL to a specified destination."""
     destination.parent.mkdir(parents=True, exist_ok=True)
-    async with aiohttp.ClientSession() as session, session.get(url) as response:
-        if response.status != 200:
-            msg = f"Failed to download {url}: {response.status}"
-            raise RuntimeError(msg)
-        content = await response.read()
-        destination.write_bytes(content)
-        print(f"Downloaded {url} to {destination}")
+    timeout = aiohttp.ClientTimeout(total=600)
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url) as response:
+            if response.status != 200:
+                msg = f"Failed to download {url}: {response.status}"
+                raise RuntimeError(msg)
+            content = await response.read()
+            destination.write_bytes(content)
+            print(f"Downloaded {url} to {destination}")

45-86: Prevent duplicate IDs within the same update and validate required keys.

Currently, duplicates inside the provided conversations list can be appended multiple times because existing_ids isn’t updated as you add. Also, missing "conversations" will raise KeyError later; validate early.

Apply:

 def add_conversations_to_split(conversations: list, dataset_dir: Path, split: str) -> None:
@@
-    # Open the dataset file for the specified split, or create it if it doesn't exist
-    dataset_file = dataset_dir / f"{split}.jsonl"
+    # Ensure output directory exists and open/create split file
+    dataset_dir.mkdir(parents=True, exist_ok=True)
+    dataset_file = dataset_dir / f"{split}.jsonl"
@@
-    existing_ids = {entry["conversation_id"] for entry in all_conversations}
+    existing_ids = {entry["conversation_id"] for entry in all_conversations}
     num_new_entries = 0
     num_duplicates = 0
     for entry in conversations:
-        if entry.get("conversation_id") is None:
+        entry_id = entry.get("conversation_id")
+        if entry_id is None:
             raise ValueError("Each conversation must have a 'conversation_id' field.")
-        if entry["conversation_id"] not in existing_ids:
+        if "conversations" not in entry:
+            raise ValueError("Each conversation must have a 'conversations' field.")
+        if entry_id not in existing_ids:
             all_conversations.append(
                 {
-                    "conversation_id": entry["conversation_id"],
+                    "conversation_id": entry_id,
                     "conversations": entry["conversations"],
                 }
             )
             num_new_entries += 1
+            existing_ids.add(entry_id)
         else:
             num_duplicates += 1
@@
-    dataset_dir.mkdir(parents=True, exist_ok=True)
     with dataset_file.open("w", encoding="utf-8") as f:
         for entry in all_conversations:
             f.write(json.dumps(entry, ensure_ascii=False) + "\n")
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)

27-29: Unify REMOVE_THINK_CHAT_TEMPLATE with training and guard None chat_template.

Avoid divergence from training preprocessing and prevent AttributeError when chat_template is None.

-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer
+try:
+    # Keep in sync with training; import if available when run from examples/speculative_decoding
+    from eagle_utils import REMOVE_THINK_CHAT_TEMPLATE
+except Exception:
+    # Fallback for standalone runs; ensure value matches eagle_utils.REMOVE_THINK_CHAT_TEMPLATE
+    REMOVE_THINK_CHAT_TEMPLATE = (
+        "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
+    )
@@
-REMOVE_THINK_CHAT_TEMPLATE = (
-    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
-)
@@
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
+    if getattr(tokenizer, "chat_template", None):
+        tokenizer.chat_template = tokenizer.chat_template.replace(
+            REMOVE_THINK_CHAT_TEMPLATE, ""
+        )

Also applies to: 96-100


138-148: Clamp/deduplicate auxiliary layer indices; avoid out‑of‑range/duplicates on small models.

Current logic can pick invalid indices (e.g., 2 when num_hidden_layers < 3).

-            # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states
+            # Extract hidden states from early/mid/late layers; clamp within [0, N-1]
             hidden_states = outputs.hidden_states
-            selected_layer_indices = [
-                2,
-                max(0, num_hidden_layers // 2),
-                max(1, num_hidden_layers - 3),
-            ]
-            selected_layer_indices = sorted(set(selected_layer_indices))
+            candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3]
+            selected_layer_indices = sorted(
+                {i for i in candidates if 0 <= i <= num_hidden_layers - 1}
+            )
+            if not selected_layer_indices:
+                selected_layer_indices = [0]
             aux_hidden_states = torch.cat(
                 [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
             )
🧹 Nitpick comments (12)
examples/speculative_decoding/main.py (3)

72-81: Annotate as Optional to match None default.

offline_data_path defaults to None but is typed as str. Make it str | None for consistency with eval_data_path and type checkers.

Apply this diff:

-    offline_data_path: str = field(
+    offline_data_path: str | None = field(

144-145: Make the offline switch robust to empty strings.

is not None treats "" as offline. Prefer truthiness.

Apply this diff:

-use_offline_training = data_args.offline_data_path is not None
+use_offline_training = bool(data_args.offline_data_path)

150-159: num_hidden_layers=0: add safe fallback and always set num_orig_hidden_layers.

Keep your space‑saving default but harden for models that can’t instantiate with 0 layers, and ensure num_orig_hidden_layers is set even when resuming or the override path changes.

Apply this diff:

-        model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
-        model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
-        )
-        if use_offline_training:
-            # When doing offline training, we need to set num_hidden_layers
-            # since we override it when loading the model for space savings
-            model_config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
-            model.config.num_orig_hidden_layers = model_config.num_hidden_layers
+        model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
+        try:
+            model = transformers.AutoModelForCausalLM.from_pretrained(
+                model_args.model_name_or_path, torch_dtype="auto", **model_kwargs
+            )
+        except Exception as e:
+            if use_offline_training:
+                print_rank_0("num_hidden_layers=0 failed; falling back to 1 for offline training.")
+                model = transformers.AutoModelForCausalLM.from_pretrained(
+                    model_args.model_name_or_path, torch_dtype="auto", num_hidden_layers=1
+                )
+            else:
+                raise
+        if use_offline_training and not hasattr(model.config, "num_orig_hidden_layers"):
+            # Record original depth for plugins that need it.
+            base_cfg = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path)
+            model.config.num_orig_hidden_layers = base_cfg.num_hidden_layers

Please run a quick smoke test on at least one model per family you target (e.g., Llama, Mistral, Qwen2, Phi‑3) to confirm no fallback is triggered unexpectedly and training proceeds.

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (2)

1-1: Add a shebang to fix ShellCheck SC2148 and enable script execution.

Without a shebang, shells may interpret this file incorrectly and ShellCheck flags SC2148.

Apply:

+#!/usr/bin/env bash
 # SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

16-25: Fix relative paths; make script location-agnostic; add strict mode and ensure output dir.

The script lives in prepare_input_conversations/, but calls into prepare_input_conversations/*.py again, which breaks when run from repo root or this directory. Also add set -euo pipefail and create a data dir.

Apply:

-# Example script to prepare a dataset of prompts for generation
-# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset.
+# Example script to prepare a dataset of prompts for generation
+# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset.
+set -euo pipefail
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+# Optional: change DATASET_DIR or pass via env
+DATASET_DIR="${DATASET_DIR:-${SCRIPT_DIR}/data}"
+mkdir -p "${DATASET_DIR}"

-python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train
-# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test
-# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test
-python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test
+python3 "${SCRIPT_DIR}/add_daring_anteater.py" --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_sharegpt.py" --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split train_sft --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split train_gen --output-split-name train ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split test_sft --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+# python3 "${SCRIPT_DIR}/add_ultrachat.py" --ultrachat-split test_gen --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"}
+python3 "${SCRIPT_DIR}/add_mtbench.py" --output-split-name mix_test ${DATASET_DIR:+--dataset-dir "${DATASET_DIR}"} 

Note: If the add_* CLIs don’t support --dataset-dir, the parameter will be omitted; the mkdir remains safe.

examples/speculative_decoding/prepare_input_conversations/utils.py (1)

93-125: Make ratio check tolerant to FP error and avoid mutating global RNG state.

isclose avoids false negatives; using a local Random(seed) preserves global RNG for callers.

Apply:

+import math
@@
-    if train_ratio + val_ratio + test_ratio != 1.0:
+    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, rel_tol=0.0, abs_tol=1e-9):
         msg = "Ratios must sum to 1.0"
         raise ValueError(msg)
@@
-    if shuffle:
-        random.seed(seed)
-        random.shuffle(conversations)
+    if shuffle:
+        random.Random(seed).shuffle(conversations)
examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1)

1-24: Add shebang/safe flags, fix debug flag name, and quote paths.

Without a shebang, running as an executable fails; ShellCheck SC2148 also flags this. The commented flag name doesn’t exist in the CLI (should be --debug-max-num-conversations). Also make invocation path-robust and quote args.

+#!/usr/bin/env bash
+set -euo pipefail
+
+# Resolve script directory to make path robust when invoked from anywhere.
+script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
-python3 collect_hidden_states/send_conversations_for_hiddens.py \
-  --model meta-llama/Llama-3.2-1B-Instruct \
-  --input-file synthetic_conversations/mtbench.jsonl \
-  --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/
-# --debug-max-num-conversations-per-split 1000
+python3 "$script_dir/send_conversations_for_hiddens.py" \
+  --model "meta-llama/Llama-3.2-1B-Instruct" \
+  --input-file "synthetic_conversations/mtbench.jsonl" \
+  --output-dir "/mnt/md0/eagle-hidden-states/llama1b/mtbench/"
+# Optional: limit processed conversations during debugging
+# --debug-max-num-conversations 1000
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (3)

93-95: Set eval() before inference.

Minor but standard to disable dropout and ensure deterministic behavior.

-    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
+    model.eval()

103-127: Counter name is misleading; it tracks too‑short and too‑long.

Tiny naming nit; adjust for clarity.

-    num_skipped_too_long = 0
+    num_filtered_by_length = 0
@@
-        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
-            num_skipped_too_long += 1
+        if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
+            num_filtered_by_length += 1
             continue
@@
-    if num_skipped_too_long > 0:
-        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
+    if num_filtered_by_length > 0:
+        print(f"Skipped {num_filtered_by_length} conversations due to length constraints.")

Also applies to: 163-164


82-83: Async isn’t used; simplify to synchronous main.

Removes unnecessary asyncio plumbing.

-async def main(args: argparse.Namespace) -> None:
+def main(args: argparse.Namespace) -> None:
@@
 if __name__ == "__main__":
     cli_args = parse_args()
-    asyncio.run(main(cli_args))
+    main(cli_args)

Also applies to: 176-178

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1)

1-23: Add shebang/safe flags, quote args, and create output dir.

Prevents exec failures (SC2148) and path issues; ensures output dir exists.

+#!/usr/bin/env bash
+set -euo pipefail
+
+# Resolve script directory
+script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Output directory (edit as needed)
+OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/"
+mkdir -p "$OUTPUT_DIR"
+
-python3 collect_hidden_states/compute_hidden_states_hf.py \
-  --model meta-llama/Llama-3.2-1B-Instruct \
-  --input-file synthetic_conversations/daring-anteater.jsonl \
-  --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+python3 "$script_dir/compute_hidden_states_hf.py" \
+  --model "meta-llama/Llama-3.2-1B-Instruct" \
+  --input-file "synthetic_conversations/daring-anteater.jsonl" \
+  --output-dir "$OUTPUT_DIR"
+# Optional:
+# --max-seq-len 3072
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1)

16-36: Harden DP runner: shebang/safe flags, temp workspace, quoting, and mkdir.

Avoids /tmp collisions, improves safety, and makes paths robust.

+#!/usr/bin/env bash
+set -euo pipefail
+
-INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
-OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+INPUT_FILE="synthetic_conversations/daring-anteater.jsonl"
+OUTPUT_DIR="/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/"
+mkdir -p "$OUTPUT_DIR"
+
-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+tmpdir="$(mktemp -d)"
+trap 'rm -rf "$tmpdir"' EXIT
+split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" "$tmpdir/part-"
 
 for i in $(seq 0 7)
 do
-CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
+  CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
+    --model "meta-llama/Llama-3.2-1B-Instruct" \
+    --input-file "$tmpdir/part-0${i}.jsonl" \
+    --output-dir "$OUTPUT_DIR" &
 done
 wait
-
-rm /tmp/part-*.jsonl
+# Temporary files cleaned via trap
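
As a possible follow-up (not part of the suggestion above), the shard count could be derived from the visible GPUs instead of being hard-coded to 8; this sketch assumes nvidia-smi is available and reuses the variables defined above:

# Count visible GPUs and shard the input accordingly.
NUM_GPUS="$(nvidia-smi -L | wc -l)"
split -n "l/${NUM_GPUS}" --numeric-suffixes=0 -d --additional-suffix=.jsonl "$INPUT_FILE" "$tmpdir/part-"

for i in $(seq 0 $((NUM_GPUS - 1)))
do
  CUDA_VISIBLE_DEVICES="$i" python3 collect_hidden_states/compute_hidden_states_hf.py \
    --model "meta-llama/Llama-3.2-1B-Instruct" \
    --input-file "$(printf '%s/part-%02d.jsonl' "$tmpdir" "$i")" \
    --output-dir "$OUTPUT_DIR" &
done
wait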
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f92be76 and f74bf59.

📒 Files selected for processing (21)
  • examples/speculative_decoding/.gitignore (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/__init__.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (1 hunks)
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (1 hunks)
  • examples/speculative_decoding/eagle_utils.py (4 hunks)
  • examples/speculative_decoding/launch.sh (0 hunks)
  • examples/speculative_decoding/launch_train.sh (4 hunks)
  • examples/speculative_decoding/main.py (4 hunks)
  • examples/speculative_decoding/prepare_input_conversations/__init__.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1 hunks)
  • examples/speculative_decoding/prepare_input_conversations/utils.py (1 hunks)
  • examples/speculative_decoding/train_eagle3_and_export.sh (2 hunks)
  • modelopt/torch/speculative/plugins/transformers.py (4 hunks)
💤 Files with no reviewable changes (1)
  • examples/speculative_decoding/launch.sh
🚧 Files skipped from review as they are similar to previous changes (13)
  • examples/speculative_decoding/.gitignore
  • examples/speculative_decoding/collect_hidden_states/__init__.py
  • examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/prepare_input_conversations/__init__.py
  • examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py
  • examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
  • examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
  • examples/speculative_decoding/train_eagle3_and_export.sh
  • examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
  • modelopt/torch/speculative/plugins/transformers.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.950Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.
📚 Learning: 2025-09-17T14:14:44.950Z
Learnt from: benchislett
PR: NVIDIA/TensorRT-Model-Optimizer#300
File: examples/speculative_decoding/main.py:150-159
Timestamp: 2025-09-17T14:14:44.950Z
Learning: In examples/speculative_decoding/main.py, setting num_hidden_layers=0 for offline training has been successfully tested by benchislett with past models and works in their use case for EAGLE3 offline training.

Applied to files:

  • examples/speculative_decoding/main.py
🧬 Code graph analysis (4)
examples/speculative_decoding/main.py (1)
examples/speculative_decoding/eagle_utils.py (1)
  • make_eagle_supervised_data_module (238-312)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (2)
examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py (2)
  • parse_args (30-90)
  • main (93-206)
examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py (2)
  • parse_args (25-46)
  • main (49-84)
examples/speculative_decoding/prepare_input_conversations/utils.py (1)
modelopt/torch/utils/random.py (2)
  • random (59-61)
  • shuffle (148-150)
examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh (1)
examples/speculative_decoding/main.py (1)
  • train (117-271)
🪛 Shellcheck (0.10.0)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
examples/speculative_decoding/main.py (2)

230-230: Good: data module gated by use_offline_training.
Clear separation between online/offline loaders and collators.


237-238: Nice: allow disabling AR validation with non-positive steps.
Useful for offline runs without inline eval.

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)

27-29: Verify single source of truth for REMOVE_THINK_CHAT_TEMPLATE

rg returned "No files were searched" due to the glob filter, so duplicates could not be confirmed. Ensure no other copies exist and that training and collection stay in lockstep. From the repo root, run and paste the output of:
rg -n --hidden --no-ignore 'REMOVE_THINK_CHAT_TEMPLATE'
Location: examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py:27-29
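
If duplicates do exist, one option is to define the constant once and import it everywhere it is needed; the module path below is purely hypothetical:

# hypothetical shared module, e.g. examples/speculative_decoding/chat_template_constants.py
REMOVE_THINK_CHAT_TEMPLATE = "..."  # single definition lives here

# compute_hidden_states_hf.py and the training code would then both do:
# from chat_template_constants import REMOVE_THINK_CHAT_TEMPLATE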

Comment on lines +129 to +133
        # Get hidden states
        with torch.inference_mode():
            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
        if num_hidden_layers is None:
            num_hidden_layers = len(outputs.hidden_states) - 1

⚠️ Potential issue

Fix device handling with device_map='auto' (model.device may be missing or wrong).

Pass inputs to the embedding device; relying on model.device can raise or send tensors to CPU.

-        with torch.inference_mode():
-            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
+        with torch.inference_mode():
+            emb_device = model.get_input_embeddings().weight.device
+            outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        # Get hidden states
-        with torch.inference_mode():
-            outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True)
-        if num_hidden_layers is None:
-            num_hidden_layers = len(outputs.hidden_states) - 1
+        # Get hidden states
+        with torch.inference_mode():
+            emb_device = model.get_input_embeddings().weight.device
+            outputs = model(input_ids=input_ids.to(emb_device), output_hidden_states=True)
+        if num_hidden_layers is None:
+            num_hidden_layers = len(outputs.hidden_states) - 1
🤖 Prompt for AI Agents
In
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
around lines 129 to 133, the code moves inputs to model.device which can be
missing or incorrect when using device_map='auto'; instead obtain the embeddings
device (e.g. model.get_input_embeddings().weight.device) and move input_ids to
that device before calling the model; use that device value as a fallback if
model.device exists but prefer the embedding weight device to ensure tensors end
up on the correct device for sharded/auto-mapped models.
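
To make the preferred lookup and fallback explicit, a small helper along these lines could be used; this is a hedged sketch, not code from the PR:

import torch


def resolve_input_device(model) -> torch.device:
    """Device that input_ids should be moved to before a forward pass.

    With device_map="auto" the model may be sharded, so prefer the device of the
    input embedding weights and fall back to model.device only if needed.
    """
    embeddings = model.get_input_embeddings()
    if embeddings is not None:
        return embeddings.weight.device
    return getattr(model, "device", torch.device("cpu"))


# Usage inside the collection loop:
# outputs = model(input_ids=input_ids.to(resolve_input_device(model)), output_hidden_states=True)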

@kevalmorabia97
Collaborator

@benchislett You need to sign your commits with an SSH key. Please take a look at https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md#%EF%B8%8F-signing-your-work

    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations.extend([json.loads(line) for line in f if line.strip()])

    if any(not entry.get("conversation_id") for entry in all_conversations):
Contributor

Hi @benchislett, there seems to be a bug on this line:

  1. When an entry has conversation_id=0, not entry.get("conversation_id") evaluates to True, so an error is raised for a valid entry.
  2. Since the fallback below already allows entries without a conversation_id, we should probably remove this check here as well.

    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations.extend([json.loads(line) for line in f if line.strip()])

    if any(not entry.get("conversation_id") for entry in all_conversations):
Contributor

Similar to the above.

    valid_entries = []
    for entry in data_json:
        conv_id = entry.get("conversation_id") or entry.get("id")
        if not conv_id:
Contributor

@h-guo18 h-guo18 Sep 18, 2025

Similar to the above, this will raise an error when conv_id=0. We should probably use if conv_id is None instead.

Contributor

Besides, on the current line 273: when conversation_id=0, the left-hand side of or is falsy, so conv_id falls through to entry.get("id") (likely None) rather than keeping 0. We probably want to do this instead:

conv_id = entry.get("conversation_id", entry.get("id"))
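
A quick illustration of the difference (the entry below is hypothetical):

entry = {"conversation_id": 0}

# Falsy check: a legitimate id of 0 is treated as "missing".
conv_id = entry.get("conversation_id") or entry.get("id")
print(conv_id)  # None, because 0 is falsy and there is no "id" key

# Default-based lookup keeps the 0.
conv_id = entry.get("conversation_id", entry.get("id"))
print(conv_id)  # 0

# Presence check that does not confuse 0 with "missing".
if conv_id is None:
    raise ValueError("entry has no conversation id")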
